%{
Source code for the paper 
"Revisiting High-resolution ODEs for Faster Convergence Rates" 

Submitted to ICLR-2024


%}

%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Add the "Source" folder to the MATLAB path (e.g. addpath('Source')) before running!
%%%%%%%%%%%%%%%%%%%%%%%%%%%%



clc
clear

% Shared problem parameters (experiments 1 and 2 may override d and L).
d = 10;      % data dimension
n = 1000;    % number of samples
mu = 0.001;  % strong convexity parameter
iter = 2e3;  % number of iterations (experiment 2 also uses it as the ODE output times 1:iter)




%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% For the exact results reported in the paper, uncomment the following line.
%rng(20)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Select the desired experiment:
% 1 for the QHM experiments (discrete-time comparison, plotted by Fig3)
% 2 for the TM experiment (TM ODE and its convergence bounds, plotted by Fig2)
% 3 for the continuous-time comparison of momentum ODEs (Figure 1)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%




experiment_number = 1;

if experiment_number == 1
    %% Experiment 1: QHM comparison in discrete time (plotted by Fig3)
    % Compares the NAG, QHM_GM^2 and QHM_GM error curves together with the two
    % closed-form rate curves rate_QHM_us and rate_QHM_zhang.

    % Function selection:
    %   1 for regularized binary classification (logistic loss + mu/2*||x||^2)
    %   2 for a one-dimensional strongly convex function
    L = 10;
    func_select = 1;
    if func_select == 1
        x_star = randn(d,1)+1;
        X = randn(n,d); % random data
        Y = (1./(1+exp(-X*x_star)) >= rand(n,1))*2-1; % random labels in {-1,+1}
        f = @(x,X,Y) 1/n*sum(log(1+exp(-Y.*X*x)))+ mu/2*norm(x)^2; % Function
        f(x_star,X,Y) % display the objective at the data-generating point
        grad = @(x,X,Y) sum(-1/n.*(Y.*X)'.*(exp(-Y.*(X*x)))'./(1+exp(-Y.*(X*x)))',2)+mu*x; % Gradient
        % Hessian: (1/n)*sum_i w_i*(y_i*x_i)*(y_i*x_i)' + mu*I, where
        % w_i = exp(-z_i)/(1+exp(-z_i))^2 and z_i = y_i*x_i'*x.
        % Per-sample weights (not (sum_i w_i)*X'*X), and mu on the diagonal only.
        hess = @(x,X,Y) (Y.*X)'*((exp(-Y.*(X*x))./(1+exp(-Y.*(X*x))).^2/n).*(Y.*X)) + mu*eye(numel(x));
        % Conservative Lipschitz bound max_i ||y_i*x_i||^2 + mu. sum(X.^2,2) equals
        % diag(X*X') without forming the n-by-n Gram matrix.
        L = max(Y.^2.*sum(X.^2,2))+mu;
    elseif func_select == 2
        d = 1;
        X = 0;
        Y = 0;
        f = @(x,X,Y) 4*(L-mu)*log(1+exp(-x))+mu/2*x^2; % Function
        grad = @(x,X,Y) -4*(L-mu).*exp(-x)./(1+exp(-x))+mu.*x; % Gradient
        hess = @(x,X,Y) 4*(L-mu)*exp(-x)./(1+exp(-x))^2+mu; % Hessian
    end

    x0 = randn(d,1);

    %% Nesterov's accelerated gradient (NAG)
    x = x0;
    y = x;
    b = 1+(sqrt(L/mu)-1)/(sqrt(L/mu)+1);
    func_val = zeros(1,iter);
    for i = 1:iter
        y_save = y;
        func_val(i) = f(x,X,Y);
        y = x - 1/L * grad(x,X,Y);
        x = b * y + (1-b)*y_save;
    end
    % x_star changes due to regularization: the final NAG iterate is used as the
    % reference minimizer for every error curve below.
    x_star = x;
    f_star = f(x_star,X,Y);
    gap0 = f(x0,X,Y)-f_star; % initial optimality gap, scales both rate curves
    error_NAG = func_val-f_star;

    %% QHM_GM^2
    a = 1/4;
    q = sqrt(a*mu);
    s = 3/(4*L);
    b = (1-q*sqrt(s))/(1+q*sqrt(s));

    x = x0;
    % NOTE(review): the momentum buffer starts at x0 (a point) rather than
    % zeros(d,1); kept so the published figures are reproduced — confirm intent.
    g = x;
    error_QHM = zeros(1,iter);
    for i = 1:iter
        error_QHM(i) = f(x,X,Y)-f_star;
        gx = grad(x,X,Y); % one gradient evaluation per iteration
        g = b*g + gx;
        x = x - s*(1-a)*gx - s*a*g;
    end
    rate_QHM_us = gap0*(1-sqrt(a*mu*s)).^(1:iter);

    %% QHM_GM
    a = 1/2;
    s = 1/(4*L);
    b = (1-2*sqrt(mu*s));

    x = x0;
    g = x; % NOTE(review): same x0 initialisation of the momentum buffer as above — confirm.
    error_QHM_zhang = zeros(1,iter);
    for i = 1:iter
        error_QHM_zhang(i) = f(x,X,Y)-f_star;
        gx = grad(x,X,Y);
        g = b*g + gx;
        x = x - s*(1-a)*gx - s*a*g;
    end
    rate_QHM_zhang = gap0*(1+a*sqrt(mu*s)/10).^(-(1:iter));

    % Columns: NAG error, QHM_GM^2 error, QHM_GM error, QHM_GM^2 rate, QHM_GM rate.
    Plt_matrix = [error_NAG',error_QHM',error_QHM_zhang',rate_QHM_us',rate_QHM_zhang'];

    Fig3(Plt_matrix)

elseif experiment_number == 2
    %% Experiment 2: TM ODE with d = 1 (plotted by Fig2)
    L = 10;
    d = 1;
    % Function selection:
    %   1 for regularized binary classification (logistic loss + mu/2*||x||^2)
    %   2 for a one-dimensional strongly convex function
    func_select = 2;
    x0 = randn(d,1);
    if func_select == 1
        x_star = randn(d,1);
        X = randn(n,d); % random data
        Y = (1./(1+exp(-X*x_star)) >= rand(n,1))*2-1; % random labels in {-1,+1}
        f = @(x,X,Y) 1/n*sum(log(1+exp(-Y.*X*x)))+ mu/2*norm(x)^2; % Function
        grad = @(x,X,Y) sum(-1/n.*(Y.*X)'.*(exp(-Y.*(X*x)))'./(1+exp(-Y.*(X*x)))',2)+mu*x; % Gradient
        % Hessian: (1/n)*sum_i w_i*(y_i*x_i)*(y_i*x_i)' + mu*I,
        % w_i = exp(-z_i)/(1+exp(-z_i))^2, z_i = y_i*x_i'*x.
        hess = @(x,X,Y) (Y.*X)'*((exp(-Y.*(X*x))./(1+exp(-Y.*(X*x))).^2/n).*(Y.*X)) + mu*eye(numel(x));
        % Conservative Lipschitz bound; sum(X.^2,2) equals diag(X*X').
        L = max(Y.^2.*sum(X.^2,2))+mu;
    elseif func_select == 2
        d = 1;
        X = 0;
        Y = 0;
        f = @(x,X,Y) 4*(L-mu)*log(1+exp(-x))+mu/2*x^2; % Function
        grad = @(x,X,Y) -4*(L-mu).*exp(-x)./(1+exp(-x))+mu.*x; % Gradient
        hess = @(x,X,Y) 4*(L-mu)*exp(-x)./(1+exp(-x))^2+mu; % Hessian
    end

    % TM method parameters
    rho = 1-sqrt(1/(L/mu));
    alpha = (1+rho)/L;
    beta = rho^2/(2-rho);
    gamma = rho^2/((1+rho)*(2-rho));
    delta = rho^2/(1-rho^2);
    M = ((1-beta)/(sqrt(alpha)*(1+beta)))^2;

    % Solve the TM ODE at t = 1,...,iter from the state [x0, 0]
    % (the meaning of the second state component is defined in TM_method_ODE).
    [t0 , y0] = ode15s(@(t,y)TM_method_ODE(L,mu,grad,X,Y,hess,t,y),(1:iter), [x0,0]);

    r = zeros(1,iter);
    for i = 1:iter
        r(i) = f(y0(i,1),X,Y);
    end

    %% Nesterov's accelerated gradient (NAG)
    x = x0;
    y = x;
    b = 1+(sqrt(L/mu)-1)/(sqrt(L/mu)+1);
    func_val = zeros(1,iter);
    for i = 1:iter
        y_save = y;
        func_val(i) = f(x,X,Y);
        y = x - 1/L * grad(x,X,Y);
        x = b * y + (1-b)*y_save;
    end
    % x_star changes due to regularization: use the final NAG iterate as the minimizer.
    x_star = x;
    f_star = f(x_star,X,Y);
    error_NAG = func_val-f_star;

    %% TMM ODE and bounds
    % Named m_coef/n_coef (not m/n) so the sample count n is not overwritten.
    m_coef = gamma*sqrt(alpha)*(1+sqrt(M*alpha));
    n_coef = 4/3*sqrt(M);

    k = 1:iter;
    % TM ODE rate ([Sun et al., 2020], A high-resolution ODE for modelling fastest known globally convergent method ...)
    rate_sun = (1.5/alpha*norm(x0-x_star)^2)*exp(-sqrt(M)/2*k);
    % TM ODE rate (this work)
    rate_new = (f(x0,X,Y)-f_star+(4/3)^2*M/(2*(1-2/3*gamma*sqrt(M*alpha)*(1+sqrt(M*alpha))))*norm(x0+m_coef/n_coef*grad(x0,X,Y)-x_star)^2)*exp(-sqrt(M)*2/3*k);

    Plt_matrix = [rate_sun',rate_new'];

    Fig2(t0,r'-f_star,Plt_matrix)

elseif experiment_number == 3
    %% Experiment 3: continuous-time comparison of momentum ODEs (Figure 1)
    % Every flow is integrated on t in [0.01, 100] from the initial state [10; 0].

    %% This work
    [t,y] = ode15s(@momentum_flow3,[0.01 100],[10;0]);

    %% Others
    [t1,y1] = ode15s(@momentum_flow_Hessian,[0.01 100],[10;0]); % Zhang
    [t2,y2] = ode15s(@momentum_flow_Hessian_correction,[0.01 100],[10;0]); % Nesterov
    [t3,y3] = ode15s(@momentum_flow_Resolution,[0.01 100],[10;0]); % Shi

    % plot draws every state component (column) of each solution.
    figure
    plot(t1,y1)
    hold on
    plot(t2,y2)
    plot(t3,y3)
    plot(t,y)
else
    error('experiment_number must be 1, 2 or 3 (got %g).', experiment_number);
end